import torch
import torch.nn as nn
import torch.nn.functional as F
from search.neuron_operations import *
from spikingjelly.activation_based import layer

# Public API of this module: the supernet factory functions defined at the
# bottom of the file (these are the names exported by a star import).
__all__ = [
    'super_resnet18',
    'super_resnet34',
    'super_vgg11',
    'super_vgg13',
    'super_vgg16',
    'super_vgg19',
]

# Registry mapping a candidate neuron's name to the callable that builds it.
# LIF / LLIF / SLIF are not defined in this file; presumably they are supplied
# by the star import from search.neuron_operations -- verify.
operations = {
    'LIF': LIF,
    'LLIF': LLIF,
    'SLIF': SLIF,
}

# VGG layouts. Every variant has five stages with the same output widths and
# differs only in how many 3x3 convolutions each stage stacks; each stage is
# closed by an 'M' marker, which super_vgg._make_layers turns into a pooling
# layer.
_VGG_STAGE_WIDTHS = (64, 128, 256, 512, 512)
_VGG_STAGE_DEPTHS = {
    'VGG11': (1, 1, 2, 2, 2),
    'VGG13': (2, 2, 2, 2, 2),
    'VGG16': (2, 2, 3, 3, 3),
    'VGG19': (2, 2, 4, 4, 4),
}

# cfg[name] is a list of stages; a stage is a list of conv widths followed by 'M'.
cfg = {
    name: [[width] * depth + ['M'] for width, depth in zip(_VGG_STAGE_WIDTHS, depths)]
    for name, depths in _VGG_STAGE_DEPTHS.items()
}

# Scalable set of candidate neuron operations.
# To add a new neuron operation, append its name to this list; the operation
# itself must be defined in detail in neuron_operations.py (super_neuron looks
# every name up in the `operations` registry).
OP = [
    'LIF',
    'LLIF',
    'SLIF',
]


class super_neuron(nn.Module):
    """Mixed neuron: a weighted sum over every candidate operation named in ``OP``.

    Each candidate is built as ``operations[name](**kwargs)``, so all
    candidates receive the same keyword arguments.
    """

    def __init__(self, **kwargs):
        super().__init__()
        # Candidates are stored in OP order; forward() pairs them positionally
        # with the architecture weights.
        self._ops = nn.ModuleList([operations[name](**kwargs) for name in OP])

    def forward(self, x, alphas):
        """Return ``sum_i alphas[i] * op_i(x)``.

        ``alphas`` is iterated in lockstep with the candidates; ``zip`` stops
        at the shorter of the two.
        """
        mixed = 0  # same start value the built-in sum() uses
        for weight, candidate in zip(alphas, self._ops):
            mixed = mixed + weight * candidate(x)
        return mixed

class resnet_block(nn.Module):
    """Residual block whose two activations are searchable ``super_neuron`` mixtures.

    Data flow: BN -> neuron 1 -> 1x1 conv (carries ``stride``) -> BN ->
    neuron 2 -> 3x3 conv -> dropout, then the shortcut branch is added. The
    shortcut is taken from the output of neuron 1, not from the raw input.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by both convolutions.
        stride: stride of the 1x1 convolution and of the projection shortcut.
        dropout: probability for the ``nn.Dropout`` after the 3x3 convolution.
        downsample: stored on the instance but not used by ``forward``.
        **kwargs: forwarded to both ``super_neuron`` instances.
    """

    expansion = 1  # output-channel multiplier, read by super_resnet
    # out_ex = 4

    def __init__(self, in_channels, out_channels, stride, dropout=0., downsample=None, **kwargs):
        super().__init__()
        shortcut_channels = self.expansion * out_channels

        self.downsample = downsample
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.super_neuron1 = super_neuron(**kwargs)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, padding=0, bias=True)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.super_neuron2 = super_neuron(**kwargs)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True)
        self.dropout = nn.Dropout(dropout)

        # An identity shortcut is only possible when the main branch keeps both
        # the resolution and the channel count; otherwise project with a 1x1 conv.
        if stride == 1 and in_channels == shortcut_channels:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Conv2d(in_channels, shortcut_channels, kernel_size=1, stride=stride, padding=0, bias=True)

    def forward(self, x, alphas):
        """Run the block; ``alphas[0]`` / ``alphas[1]`` weight neuron 1 / neuron 2."""
        activated = self.super_neuron1(self.bn1(x), alphas[0])
        out = self.conv1(activated)
        out = self.super_neuron2(self.bn2(out), alphas[1])
        out = self.dropout(self.conv2(out))
        # The residual branch starts from the first neuron's output.
        out += self.shortcut(activated)
        return out

# class CustomSequential(nn.Module):
#     def __init__(self, *args):
#         super(CustomSequential, self).__init__()
#         self.blocks = nn.ModuleList(args)
#
#     def forward(self, x, alphas):
#         for block in self.blocks:
#             x = block(x, alphas)
#         return x

class super_resnet(nn.Module):
    """ResNet-style supernet whose neurons are mixtures weighted by learnable alphas.

    The network is a 3x3 stem convolution, four stages of ``block`` modules
    (64/128/256/512 base channels, strides 1/2/2/2), then BN -> mixed neuron ->
    average pooling -> flatten -> dropout -> linear classifier.

    Architecture parameters live in one ``(1 + 2 * sum(layers_num), len(OP))``
    tensor: two rows per block (one per mixed neuron) plus a final row for the
    neuron in front of the classifier.

    NOTE: the head is ``AvgPool2d(4)`` followed by ``Linear(512 * expansion, ...)``,
    so the flattened size only matches when the pooled map is 1x1 -- e.g. 32x32
    inputs, which the three stride-2 stages reduce to 4x4 (assuming the neuron
    operations preserve the spatial shape).

    Args:
        block: block class; must expose ``expansion`` and accept
            ``(in_channels, out_channels, stride, dropout, **kwargs)``.
        layers_num: number of blocks in each of the four stages.
        num_classes: size of the classifier output.
        dropout: dropout probability passed to every block and to the head.
        **kwargs: ``c_in`` (input channels, default 3) is read here; the whole
            dict is also forwarded to every block and ``super_neuron``.
    """

    def __init__(self, block=resnet_block, layers_num=(2, 2, 2, 2), num_classes=10, dropout=0., **kwargs):
        super(super_resnet, self).__init__()
        self.data_channels = kwargs.get('c_in', 3)
        self.init_channels = 64
        self.total_layers = sum(layers_num)
        self.conv1 = nn.Conv2d(self.data_channels, self.init_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self._init_archi_alphas()
        self.blocks = nn.ModuleList()

        # _make_layer appends to self.blocks and returns nothing, so its result
        # is deliberately not bound to an attribute.
        stage_plan = ((64, 1), (128, 2), (256, 2), (512, 2))
        for stage_idx, (width, stride) in enumerate(stage_plan):
            self._make_layer(block, width, layers_num[stage_idx], stride, dropout, **kwargs)

        self.bn1 = nn.BatchNorm2d(512 * block.expansion)
        self.pool = nn.AvgPool2d(4)
        self.flat = nn.Flatten()
        self.drop = layer.Dropout(dropout)

        self.linear = nn.Linear(512 * block.expansion, num_classes)
        # self.linear = nn.Linear(512 * block.out_ex, num_classes)

        self.sn1 = super_neuron(**kwargs)

    def _make_layer(self, block, out_channels, num_blocks, stride, dropout, **kwargs):
        """Append one stage of ``num_blocks`` blocks to ``self.blocks``.

        Only the first block of the stage uses ``stride``; the rest use 1.
        ``self.init_channels`` is advanced so the next block sees the right
        input width. Returns ``None``.
        """
        block_strides = [stride] + [1] * (num_blocks - 1)
        for block_stride in block_strides:
            self.blocks.append(block(self.init_channels, out_channels, block_stride, dropout, **kwargs))
            self.init_channels = out_channels * block.expansion

    def _init_archi_alphas(self):
        """Create the architecture parameters (small random values near zero)."""
        num_ops = len(OP)
        # Other initializations of the architecture weights are possible.
        self._alphas_archi = nn.Parameter(1e-3 * torch.randn(1 + 2 * self.total_layers, num_ops), requires_grad=True)
        self._alphas = [self._alphas_archi]

    def alphas(self):
        """Return the architecture parameters as a one-element list."""
        return self._alphas

    def alphas_tensor(self):
        """Return the raw (pre-softmax) architecture parameter tensor."""
        return self._alphas[0]

    def weights(self):
        """Return an iterator over all module parameters.

        NOTE(review): ``_alphas_archi`` is registered as an ``nn.Parameter`` on
        this module, so it is included here as well. Confirm against the caller
        whether the weight optimizer is meant to see the alphas.
        """
        return self.parameters()

    def _forward_implicit(self, x):
        """Forward pass using the softmax of the alphas as mixing weights."""
        out = self.conv1(x)
        alphas = self.alphas_tensor()
        # Normalize each row over the candidate operations.
        alphas = F.softmax(alphas, dim=1)
        for i, block in enumerate(self.blocks):
            # Rows 2i and 2i+1 belong to the two mixed neurons of block i.
            out = block(out, alphas[i*2:(i+1)*2, :])
        # The last row drives the neuron in front of the classifier.
        out = self.pool(self.sn1(self.bn1(out), alphas[-1]))
        out = self.drop(self.flat(out))
        out = self.linear(out)
        return out

    def forward(self, x):
        """Return the class logits for ``x``."""
        return self._forward_implicit(x)

    def loss(self, x, y):
        """Cross-entropy of ``forward(x)`` against ``y``.

        Used to calculate w' in the virtual step, not as the real training loss.
        """
        return F.cross_entropy(self.forward(x), y)


class VGG_Sequential(nn.Sequential):
    """``nn.Sequential`` variant that feeds architecture weights to mixed neurons.

    Unlike the ResNet blocks, a VGG stage is a flat module list, so the
    container itself hands out one row of ``alphas`` per ``super_neuron``.
    """

    def forward(self, x, alphas):
        """Apply the children in order.

        The k-th ``super_neuron`` encountered receives ``alphas[k]``; every
        other child is called with the activations only.
        """
        next_row = 0
        for child in self:
            if not isinstance(child, super_neuron):
                x = child(x)
                continue
            x = child(x, alphas[next_row])
            next_row += 1
        return x

class super_vgg(nn.Module):
    """VGG-style supernet whose neurons are mixtures weighted by learnable alphas.

    ``cfg[vgg_name]`` supplies five stages. Every convolution in a stage is
    followed by BatchNorm, a ``super_neuron`` and dropout, so there is one row
    of architecture parameters per convolution.

    Args:
        vgg_name: key into the module-level ``cfg`` table.
        dropout: probability for the dropout layer after each mixed neuron.
        num_classes: size of the classifier output.
        **kwargs: ``c_in`` (input channels, default 2) is read here; the whole
            dict is also forwarded to every ``super_neuron``.
    """

    def __init__(self, vgg_name, dropout=0.0, num_classes=10, **kwargs):
        super().__init__()
        self.whether_bias = True
        self.init_channels = kwargs.get('c_in', 2)
        stage_cfgs = cfg[vgg_name]
        # Cumulative neuron counts; entry i is where stage i's alpha rows end.
        self.layers = self._count_layers(stage_cfgs)
        self.total_layers = self.layers[-1]
        self._init_archi_alphas()
        self.layer1 = self._make_layers(stage_cfgs[0], dropout, **kwargs)
        self.layer2 = self._make_layers(stage_cfgs[1], dropout, **kwargs)
        self.layer3 = self._make_layers(stage_cfgs[2], dropout, **kwargs)
        self.layer4 = self._make_layers(stage_cfgs[3], dropout, **kwargs)
        self.layer5 = self._make_layers(stage_cfgs[4], dropout, **kwargs)
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(512 * 7 * 7, num_classes),
        )
        self._reset_standard_layers()

    def _reset_standard_layers(self):
        """Re-initialize every Conv2d, BatchNorm2d and Linear submodule."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, 0, 0.01)
                nn.init.constant_(module.bias, 0)

    def _make_layers(self, cfg, dropout, **kwargs):
        """Build one stage from a list of conv widths and ``'M'`` pool markers.

        ``self.init_channels`` is advanced after every convolution so the next
        one sees the right input width.
        """
        modules = []
        for item in cfg:
            if item == 'M':
                modules.append(nn.AvgPool2d(kernel_size=2, stride=2))
                continue
            modules.extend([
                nn.Conv2d(self.init_channels, item, kernel_size=3, padding=1, bias=self.whether_bias),
                nn.BatchNorm2d(item),
                super_neuron(**kwargs),
                layer.Dropout(dropout),
            ])
            self.init_channels = item
        return VGG_Sequential(*modules)

    def _forward_implicit(self, x):
        """Forward pass using the softmax of the alphas as mixing weights."""
        mix = F.softmax(self.alphas_tensor(), dim=1)
        stages = (self.layer1, self.layer2, self.layer3, self.layer4, self.layer5)
        out = x
        row_start = 0
        # Each stage consumes the alpha rows between consecutive boundaries.
        for stage, row_stop in zip(stages, self.layers):
            out = stage(out, mix[row_start:row_stop, :])
            row_start = row_stop
        out = self.avgpool(out)
        return self.classifier(out)

    def forward(self, x):
        """Return the class logits for ``x``."""
        return self._forward_implicit(x)

    def _count_layers(self, cfg):
        """Return the cumulative number of non-string entries per stage.

        Non-string entries are the conv widths, i.e. one mixed neuron each.
        """
        running_total = 0
        boundaries = []
        for stage_cfg in cfg:
            running_total += len([entry for entry in stage_cfg if not isinstance(entry, str)])
            boundaries.append(running_total)
        return boundaries

    def _init_archi_alphas(self):
        """Create the architecture parameters (small random values near zero)."""
        # Other initializations of the architecture weights are possible.
        shape = (self.total_layers, len(OP))
        self._alphas_archi = nn.Parameter(1e-3 * torch.randn(*shape), requires_grad=True)
        self._alphas = [self._alphas_archi]

    def alphas(self):
        """Return the architecture parameters as a one-element list."""
        return self._alphas

    def alphas_tensor(self):
        """Return the raw (pre-softmax) architecture parameter tensor."""
        return self._alphas[0]

    def weights(self):
        """Return an iterator over all parameters registered on this module."""
        return self.parameters()

    def loss(self, x, y):
        """Cross-entropy of ``forward(x)`` against ``y``.

        Used to calculate w' in the virtual step, not as the real training loss.
        """
        logits = self.forward(x)
        return F.cross_entropy(logits, y)

def _super_model_resnet(block, layers_num, num_classes, **kwargs):
    """Build a ``super_resnet``; extra keyword arguments are forwarded unchanged."""
    model = super_resnet(block, layers_num, num_classes, **kwargs)
    return model

def _super_model_vgg(vgg_name, **kwargs):
    """Build a ``super_vgg`` for ``vgg_name``; keyword arguments are forwarded unchanged."""
    model = super_vgg(vgg_name, **kwargs)
    return model

def super_resnet18(num_classes, **kwargs):
    """Build the ResNet-18-style supernet: four stages of two ``resnet_block`` each."""
    stage_depths = [2, 2, 2, 2]
    return _super_model_resnet(resnet_block, stage_depths, num_classes=num_classes, **kwargs)

def super_resnet34(num_classes, **kwargs):
    """Build the ResNet-34-style supernet: stages of 3, 4, 6 and 3 ``resnet_block``."""
    stage_depths = [3, 4, 6, 3]
    return _super_model_resnet(resnet_block, stage_depths, num_classes=num_classes, **kwargs)

def super_vgg11(num_classes=10, neuron_dropout=0.0, **kwargs):
    """Build the VGG11 supernet; ``neuron_dropout`` is passed to ``super_vgg`` as ``dropout``."""
    model = super_vgg('VGG11', num_classes=num_classes, dropout=neuron_dropout, **kwargs)
    return model

def super_vgg13(num_classes=10, neuron_dropout=0.0, **kwargs):
    """Build the VGG13 supernet; ``neuron_dropout`` is passed to ``super_vgg`` as ``dropout``."""
    model = super_vgg('VGG13', num_classes=num_classes, dropout=neuron_dropout, **kwargs)
    return model

def super_vgg16(num_classes=10, neuron_dropout=0.0, **kwargs):
    """Build the VGG16 supernet; ``neuron_dropout`` is passed to ``super_vgg`` as ``dropout``."""
    model = super_vgg('VGG16', num_classes=num_classes, dropout=neuron_dropout, **kwargs)
    return model

def super_vgg19(num_classes=10, neuron_dropout=0.0, **kwargs):
    """Build the VGG19 supernet; ``neuron_dropout`` is passed to ``super_vgg`` as ``dropout``."""
    model = super_vgg('VGG19', num_classes=num_classes, dropout=neuron_dropout, **kwargs)
    return model



# Smoke test: build a ResNet-18 supernet and push one random binary frame through it.
if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # super_model = super_vgg11().to(device)
    super_model = super_resnet18(num_classes=200).to(device)
    print(super_model)

    # Use a 32x32 input: the three stride-2 stages reduce it to 4x4, which
    # AvgPool2d(4) collapses to the 512 features the final nn.Linear expects.
    # A 64x64 input would reach the pool as 8x8 and flatten to 512 * 2 * 2
    # features, which does not match that layer (assuming the neuron
    # operations keep the spatial shape).
    x = torch.rand([1, 3, 32, 32], device=device)
    # Binarize the input so it looks like a spike frame.
    x = (x >= 0.5).float()
    print(super_model(x))


